import torch
from torch import nn


if __name__ == '__main__':
    # Show that nn.Linear transforms only the trailing feature dimension:
    # the leading dimensions of each input (batch, then batch + sequence)
    # come through unchanged, and the last one goes from in_features to
    # out_features.
    in_features, out_features = 8, 4
    batch_size, seq_length = 2, 6
    fc = nn.Linear(in_features, out_features)

    # Shapes are tried in this order (the order also fixes the RNG draws):
    #   (batch_size, in_features)              -> (batch_size, out_features)
    #   (batch_size, seq_length, in_features)  -> (batch_size, seq_length, out_features)
    input_shapes = (
        (batch_size, in_features),
        (batch_size, seq_length, in_features),
    )
    for shape in input_shapes:
        X = torch.randn(shape)
        print('X:', X.size())
        output = fc(X)
        print('output:', output.size())
    